# -*- coding: utf-8 -*-
"""
Created on Wed Mar 23 18:40:39 2022

@author: JM wrote the first kernel of the code (persistence plots and reading in the data); APR wrote the rest.
"""
"""The first kernel defines basic functions and reads in the data.
The code is designed to profile the switching of a motor (or something similar) using the signal from a photodiode (PD).
It is used to find the switching time.
It analyses vibrations present during and after the switch, and can output their peak-to-peak (p2p) height and frequency.
It also gives a response time for the switch.

It is designed to be used with three types of data:
1. A profile of the entire switch.
2. A profile of the vibrations/oscillations at the end of the switch.
3. A profile of the very beginning of the switch, to find the switch response time.

After running the first kernel to read in the data, run whichever of the next three
kernels matches what is present in the data.
"""
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
import matplotlib 
import os
import pandas as pd
import scipy.optimize
from numpy.random import rand,normal
import pickle
from scipy.stats import sem
from scipy.signal import find_peaks

def rotation_fit(t,t0,t1,A,w,phi,c):
    """Piecewise cosine model used to fit rotation data.

    Before ``t0`` the model holds its starting level, between ``t0`` and
    ``t1`` it follows ``A*cos(w*(t - t0) + phi) + c``, and after ``t1`` it
    holds the level reached at ``t1``.

    Parameters
    ----------
    t : numpy.ndarray
        Sample times.
    t0, t1 : float
        Start and end of the moving segment.
    A, w, phi, c : float
        Amplitude, angular frequency, phase and offset of the cosine.

    Returns
    -------
    numpy.ndarray
        Model value at every time in ``t``.
    """
    level_before = A * np.cos(phi) + c
    level_after = A * np.cos(w * (t1 - t0) + phi) + c
    moving = A * np.cos(w * (t - t0) + phi) + c
    # The t > t1 branch wins over t < t0, matching the original overwrite order.
    return np.where(t > t1, level_after, np.where(t < t0, level_before, moving))

def persistance_plot(t_values,data_list,axes =None,vbins = None):
    """Plot a list of traces sharing one time base as a 2D histogram.

    Imitates the persistence display mode of an oscilloscope. Designed to work
    with digitised data from a PicoScope.

    Parameters
    ----------
    t_values : array_like
        Time stamps shared by every trace (length ``N``).
    data_list : array_like
        Sequence of traces, each with ``N`` samples.
    axes : matplotlib.axes.Axes, optional
        Axes to draw into. If omitted, a new figure is created and its
        current axes are used.
    vbins : int, optional
        Number of vertical (voltage) bins. Defaults to the number of distinct
        values in the first trace.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the histogram was drawn on.
    """
    if axes is None:
        plt.figure()
        axes = plt.gca()
    if vbins is None:
        vbins=len(np.unique(data_list[0]))
    print("Number of vertical bins",vbins)

    t_values = np.asarray(t_values)
    # About four samples per horizontal bin, but always at least one bin.
    tbins = max(1, len(t_values) // 4)
    _, _, _, image = axes.hist2d(
        np.tile(t_values, len(data_list)),
        np.ravel(data_list),
        bins=(tbins, vbins),
        norm=matplotlib.colors.LogNorm(),
    )
    # Attach the colorbar to the image just drawn, on the axes that hold it.
    plt.colorbar(image, ax=axes)
    return axes
    



folder_path = r"Y:\Microscope\People\adarsh\Hollow stepper\Switch time\\"

measure = "400k_full\\"
path = folder_path+measure

V_Alist= []
V_Blist = []
i = 0
read = True

if read:
    # Read every PicoScope .mat export in `path` and cache the traces as
    # pickles, so later runs can set read = False and skip the .mat parsing.
    # NOTE: the backslash-separated paths below assume Windows.
    os.makedirs(os.path.dirname("testingdata\\"+measure),exist_ok=True)

    # Reset so a stale `times` from an earlier interactive run is never saved.
    times = None
    # Sorted so the traces are always stacked in the same (filename) order.
    for file in sorted(os.listdir(path)):
        print(file)
        if not file.endswith(".mat"):
            continue

        matfile = scipy.io.loadmat(path+file)
        print(matfile.keys())
        tstart = matfile['Tstart'][0][0]
        samples = matfile['RequestedLength'][0][0] + \
            matfile['ExtraSamples'][0][0]
        # Sample k is taken at tstart + k*Tinterval. (np.linspace ending at
        # tstart + samples*Tinterval would stretch the spacing by
        # samples/(samples-1).)
        times = tstart + np.arange(samples)*matfile['Tinterval'][0][0]
        voltageB = np.array(matfile['B'])
        voltageA = np.array(matfile['A'])
        V_Alist.append(voltageA)
        V_Blist.append(voltageB)
        i+=1

    if times is None:
        raise FileNotFoundError("No .mat files found in " + path)

    with open("testingdata\\"+measure+"V_Adata.pkl",'wb') as f:
        pickle.dump(V_Alist,f)
    with open("testingdata\\"+measure+"V_Bdata.pkl",'wb') as f:
        pickle.dump(V_Blist,f)
    with open("testingdata\\"+measure+"times.pkl",'wb') as f:
        pickle.dump(times,f)


# Load the cached traces (pickles written above by this script; do not point
# this at untrusted files, since unpickling can execute code).
with open("testingdata\\"+measure+"V_Adata.pkl",'rb') as f:
    V_Alist = pickle.load(f)
with open("testingdata\\"+measure+"times.pkl",'rb') as f:
    times = pickle.load(f)
with open("testingdata\\"+measure+"V_Bdata.pkl",'rb') as f:
    V_Blist = pickle.load(f)
print(len(V_Alist))


# trace_a, trace_b: one row per loaded file, for channels A and B.
trace_a=np.array(V_Alist)
trace_b=np.array(V_Blist)

figurepath = "figures\\"+measure
os.makedirs(os.path.dirname(figurepath),exist_ok=True)

#%%
"""
Run this to get a switch time with an input of the full switching profile
"""
# Both PD channels are shifted by a fixed 0.16 V before analysis.
# NOTE(review): origin of the 0.16 V offset is not recorded here — confirm.
trace_a=np.array(V_Alist)+0.16
trace_b=np.array(V_Blist) +0.16
persistance_plot(times*1000,1000*trace_a)
plt.xlabel('Time (ms)')
plt.ylabel('Voltage in PD A (mV)')
plt.show()
groups = 80 #data is averaged to cancel out some high frequency noise we see
start=times[0]
stop=np.max(times)
interval=times[2]-times[1]
s_times=[]      # switch time of each analysed trace (s)
s_height=[]     # binned voltage at the first minimum of each trace
avg_end=[]      # settled voltage near the end of each trace
p2_times=[]     # time of the second minimum of each trace (s)
for k in np.arange(0,len(trace_a)):
    # Average consecutive blocks of `groups` samples; every complete block is
    # kept and an incomplete tail (if any) is dropped.
    n_bins = trace_a[k].size // groups
    volt_bin = np.ravel(trace_a[k])[:n_bins*groups].reshape(n_bins, groups).mean(axis=1)
    # Minima of the trace are found as peaks of the inverted signal.
    peaks=find_peaks(-1*volt_bin,prominence=0.002)
    bin_times_ms = np.arange(0,len(volt_bin))*interval*groups*1000 +start*1000
    plt.plot(bin_times_ms,volt_bin,'r')
    for peak in peaks[0]:
        print(peak)
        plt.plot([bin_times_ms[peak],bin_times_ms[peak]],[-1,1],'b')
    plt.xlabel('Time(ms)')
    plt.ylabel('PD Voltage')
    plt.ylim(0,1.1*np.max(trace_a))
    plt.show()
    if len(peaks[0]) < 2:
        # The switch time and oscillation period need the first two minima.
        print('Skipping trace',k,'- found',len(peaks[0]),'minima, need 2')
        continue
    # NOTE(review): times are taken at bin (peak - 1), one bin before the
    # marker drawn above; kept from the original analysis — confirm intent.
    s_times.append(interval*(peaks[0][0]-1)*groups+start)
    p2_times.append(interval*(peaks[0][1]-1)*groups+start)
    s_height.append(volt_bin[peaks[0][0]])
    # Settled level: mean of the 10 bins before the final bin.
    avg_end.append(np.mean(volt_bin[len(volt_bin)-11:len(volt_bin)-1]))

# Switch-time uncertainty: scatter between traces combined in quadrature with
# the bin width (the time resolution of the averaged data).
print('Mean switch time is ', np.mean(s_times),((np.std(s_times)**2)+((interval*groups)**2))**0.5)
print('Time resolution of time averaged data is',interval*groups,'s')
# Peak-to-peak taken as twice the depth of the first minimum below the settled level.
p2p=2*(np.array(avg_end)-np.array(s_height))
print('Peak to Peak height in mV is',1000*np.mean(p2p),1000*np.std(p2p))
# Period of each trace individually, so the spread is measured across traces.
periods=np.array(p2_times)-np.array(s_times)
print('Time period of 1 oscillation is',np.mean(periods),np.std(periods))
#%%
# Overlay the last trace binned in the switch-time cell (red) with its
# complement (blue), and mark the minima found there. Reads `volt_bin`,
# `peaks`, `interval`, `groups` and `start` left behind by that cell.
# This cell only plots; it does not touch s_height, p2_times or avg_end.
# NOTE(review): 0.31 V is presumably the total (transmitted + reflected) PD
# signal, so 0.31 - volt_bin approximates the reflected port — confirm.
bin_times_ms = np.arange(0,len(volt_bin))*interval*groups*1000 +start*1000
plt.plot(bin_times_ms,volt_bin,color='r')
plt.plot(bin_times_ms,.31-volt_bin,color='b')
for peak in peaks[0]:
    print(peak)
    marker_ms = peak*interval*groups*1000 +start*1000
    plt.plot([marker_ms,marker_ms],[-1,1],color='black')
plt.xlabel('Time(ms)')
plt.ylabel('PD Voltage')
plt.ylim(0,1.01*np.max(trace_a))
plt.legend(['Transmitted by PBS','Reflected by PBS'],framealpha=1)
plt.show()
#%%
'''
Run on oscillations at the end of the trace
Uses a ft to output the frequency
'''
from scipy.fft import fft,fftfreq
groups = 1
start=times[0]
stop=np.max(times)
interval=times[2]-times[1]
trace_index = 2   # which trace to analyse
skip = 10000      # leading samples discarded before the FFT
# Average consecutive blocks of `groups` samples from the chosen trace.
n_bins = trace_a[trace_index].size // groups
volt_bin = np.ravel(trace_a[trace_index])[:n_bins*groups].reshape(n_bins, groups).mean(axis=1)
tail = volt_bin[skip:]
N = len(tail)
spectrum = fft(tail)

# sample spacing
T = interval*groups
xf = fftfreq(N, T)[:N//2]
# Single-sided amplitude spectrum.
yf=2.0/N * np.abs(spectrum[0:N//2])
plt.plot(xf,yf)
plt.xlim(0,300)
plt.ylim(0,0.004)
# Only the first 50 frequency bins are searched for peaks.
fpeaks=find_peaks(yf[0:50],prominence=0.00005)
for fpeak in fpeaks[0]:
    print(fpeak)
    plt.plot([xf[fpeak],xf[fpeak]],[0,1])
if len(fpeaks[0]) == 0:
    print('No frequency peak found in the first 50 bins')
else:
    # Report the highest-frequency peak found, with the FFT bin spacing as
    # its resolution, followed by its amplitude.
    fpeak = fpeaks[0][-1]
    print(xf[fpeak],xf[fpeak]-xf[fpeak-1])
    print(yf[fpeak])
    #%%
'''
Run this to get a response time for the switch.
Set `rising = False` for a falling trace: the response is then the first bin
below 0.995*v_init instead of the first bin above 1.005*v_init.
'''
groups = 70
start=times[0]
interval=times[2]-times[1]
rising = True   # direction of the first change in the trace
offset = 0.11   # NOTE(review): hard-coded voltage offset added to every bin — confirm origin
volt_bin=[]
rtime_list=[]
for k in np.arange(0,len(trace_a)):
    # Average consecutive blocks of `groups` samples, then add the offset.
    n_bins = trace_a[k].size // groups
    volt_bin = offset + np.ravel(trace_a[k])[:n_bins*groups].reshape(n_bins, groups).mean(axis=1)
    plt.plot(np.arange(0,len(volt_bin))*interval*groups + start,volt_bin)
    # Initial level: mean of the first 30 bins.
    v_init=np.mean(volt_bin[0:30])
    if rising:
        crossed = volt_bin > 1.005*v_init
    else:
        crossed = volt_bin < 0.995*v_init
    if not crossed.any():
        continue   # this trace never leaves the initial band
    first = np.argmax(crossed)   # index of the first bin outside the band
    rtime=interval*(first-1)*groups+start
    rtime_list.append(rtime)
    plt.plot([rtime,rtime],[-1,1])
    plt.ylim(0.9*np.min(volt_bin),.22)
print(np.mean(rtime_list)*1000,np.std(rtime_list)*1000)

#%%
# Mark the minima of every trace, after discarding its first 3000 samples,
# using 50-sample averages with a 0.11 V offset added.
# Relies on `interval` having been set by an earlier cell.
groups=50
start=times[0]
for k in np.arange(0,len(trace_a)):
    volt_bin=[]
    # Bin averages start at sample 3000. The range stops one group short when
    # (length - 3000) is an exact multiple of `groups`.
    for i in np.arange(3000,len(trace_a[k,:])-groups,groups):
        volt_bin.append(0.11+np.mean(trace_a[k,i:i+groups]))
    volt_bin=np.array(volt_bin)
    # NOTE(review): the time axis starts at `start`, but bin 0 begins at sample
    # 3000, so plotted times are 3000*interval earlier than the true sample
    # times — confirm whether start + 3000*interval was intended.
    plt.plot(np.arange(0,len(volt_bin))*interval*groups + start,volt_bin)
    # Minima are found as peaks of the inverted binned trace.
    peaks=find_peaks(-1*volt_bin,prominence=0.0009)
    for peak in peaks[0]:
        print(peak)
        # Vertical marker at the (offset-less) start time of the bin holding each minimum.
        plt.plot([(peak)*interval*groups+start,(peak)*interval*groups+start],[-1,1])
        plt.ylim(0.2,0.3) 
        #plt.xlim(0.03,0.07)